{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    " \n",
    "# Training set: y = 2x, so the optimal weight is w = 2.\n",
    "x_data = [1.0, 2.0, 3.0]\n",
    "y_data = [2.0, 4.0, 6.0]\n",
    " \n",
    "# Initial weight guess and the fixed learning-rate (step size).\n",
    "w = 1.0\n",
    "LEARNING_RATE = 0.01\n",
    " \n",
    "def forward(x):\n",
    "    \"\"\"Linear model prediction: y_hat = x * w.\"\"\"\n",
    "    return x * w\n",
    " \n",
    "def cost(xs, ys):\n",
    "    \"\"\"Mean squared error of the model over the whole training set.\"\"\"\n",
    "    # Avoid shadowing the function name with a local accumulator.\n",
    "    total = sum((forward(x) - y) ** 2 for x, y in zip(xs, ys))\n",
    "    return total / len(xs)\n",
    " \n",
    "def gradient(xs, ys):\n",
    "    \"\"\"Analytic gradient of the MSE cost w.r.t. w: mean of 2*x*(x*w - y).\"\"\"\n",
    "    total = sum(2 * x * (x * w - y) for x, y in zip(xs, ys))\n",
    "    return total / len(xs)\n",
    " \n",
    "epoch_list = []\n",
    "cost_list = []\n",
    "print('predict (before training)', 4, forward(4))\n",
    "for epoch in range(100):\n",
    "    # Batch gradient descent: one weight update per epoch,\n",
    "    # computed from the full dataset.\n",
    "    cost_val = cost(x_data, y_data)\n",
    "    grad_val = gradient(x_data, y_data)\n",
    "    w -= LEARNING_RATE * grad_val\n",
    "    print('epoch:', epoch, 'w=', w, 'loss=', cost_val)\n",
    "    epoch_list.append(epoch)\n",
    "    cost_list.append(cost_val)\n",
    " \n",
    "print('predict (after training)', 4, forward(4))\n",
    "plt.plot(epoch_list, cost_list)\n",
    "plt.ylabel('cost')\n",
    "plt.xlabel('epoch')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict (before training) 4 4.0\n",
      "\tgrad: 1.0 2.0 -2.0\n",
      "\tgrad: 2.0 4.0 -7.84\n",
      "\tgrad: 3.0 6.0 -16.2288\n",
      "progress: 0 w= 1.260688 loss= 4.919240100095999\n",
      "\tgrad: 1.0 2.0 -1.478624\n",
      "\tgrad: 2.0 4.0 -5.796206079999999\n",
      "\tgrad: 3.0 6.0 -11.998146585599997\n",
      "progress: 1 w= 1.453417766656 loss= 2.688769240265834\n",
      "\tgrad: 1.0 2.0 -1.093164466688\n",
      "\tgrad: 2.0 4.0 -4.285204709416961\n",
      "\tgrad: 3.0 6.0 -8.87037374849311\n",
      "progress: 2 w= 1.5959051959019805 loss= 1.4696334962911515\n",
      "\tgrad: 1.0 2.0 -0.8081896081960389\n",
      "\tgrad: 2.0 4.0 -3.1681032641284723\n",
      "\tgrad: 3.0 6.0 -6.557973756745939\n",
      "progress: 3 w= 1.701247862192685 loss= 0.8032755585999681\n",
      "\tgrad: 1.0 2.0 -0.59750427561463\n",
      "\tgrad: 2.0 4.0 -2.3422167604093502\n",
      "\tgrad: 3.0 6.0 -4.848388694047353\n",
      "progress: 4 w= 1.7791289594933983 loss= 0.43905614881022015\n",
      "\tgrad: 1.0 2.0 -0.44174208101320334\n",
      "\tgrad: 2.0 4.0 -1.7316289575717576\n",
      "\tgrad: 3.0 6.0 -3.584471942173538\n",
      "progress: 5 w= 1.836707389300983 loss= 0.2399802903801062\n",
      "\tgrad: 1.0 2.0 -0.3265852213980338\n",
      "\tgrad: 2.0 4.0 -1.2802140678802925\n",
      "\tgrad: 3.0 6.0 -2.650043120512205\n",
      "progress: 6 w= 1.8792758133988885 loss= 0.1311689630744999\n",
      "\tgrad: 1.0 2.0 -0.241448373202223\n",
      "\tgrad: 2.0 4.0 -0.946477622952715\n",
      "\tgrad: 3.0 6.0 -1.9592086795121197\n",
      "progress: 7 w= 1.910747160155559 loss= 0.07169462478267678\n",
      "\tgrad: 1.0 2.0 -0.17850567968888198\n",
      "\tgrad: 2.0 4.0 -0.6997422643804168\n",
      "\tgrad: 3.0 6.0 -1.4484664872674653\n",
      "progress: 8 w= 1.9340143044689266 loss= 0.03918700813247573\n",
      "\tgrad: 1.0 2.0 -0.13197139106214673\n",
      "\tgrad: 2.0 4.0 -0.5173278529636143\n",
      "\tgrad: 3.0 6.0 -1.0708686556346834\n",
      "progress: 9 w= 1.9512159834655312 loss= 0.021418922423117836\n",
      "\tgrad: 1.0 2.0 -0.09756803306893769\n",
      "\tgrad: 2.0 4.0 -0.38246668963023644\n",
      "\tgrad: 3.0 6.0 -0.7917060475345892\n",
      "progress: 10 w= 1.9639333911678687 loss= 0.01170720245384975\n",
      "\tgrad: 1.0 2.0 -0.07213321766426262\n",
      "\tgrad: 2.0 4.0 -0.2827622132439096\n",
      "\tgrad: 3.0 6.0 -0.5853177814148953\n",
      "progress: 11 w= 1.9733355232910992 loss= 0.006398948863435593\n",
      "\tgrad: 1.0 2.0 -0.05332895341780164\n",
      "\tgrad: 2.0 4.0 -0.2090494973977819\n",
      "\tgrad: 3.0 6.0 -0.4327324596134101\n",
      "progress: 12 w= 1.9802866323953892 loss= 0.003497551760830656\n",
      "\tgrad: 1.0 2.0 -0.039426735209221686\n",
      "\tgrad: 2.0 4.0 -0.15455280202014876\n",
      "\tgrad: 3.0 6.0 -0.3199243001817109\n",
      "progress: 13 w= 1.9854256707695 loss= 0.001911699652671057\n",
      "\tgrad: 1.0 2.0 -0.02914865846100012\n",
      "\tgrad: 2.0 4.0 -0.11426274116712065\n",
      "\tgrad: 3.0 6.0 -0.2365238742159388\n",
      "progress: 14 w= 1.9892250235079405 loss= 0.0010449010656399273\n",
      "\tgrad: 1.0 2.0 -0.021549952984118992\n",
      "\tgrad: 2.0 4.0 -0.08447581569774698\n",
      "\tgrad: 3.0 6.0 -0.17486493849433593\n",
      "progress: 15 w= 1.9920339305797026 loss= 0.0005711243580809696\n",
      "\tgrad: 1.0 2.0 -0.015932138840594856\n",
      "\tgrad: 2.0 4.0 -0.062453984255132156\n",
      "\tgrad: 3.0 6.0 -0.12927974740812687\n",
      "progress: 16 w= 1.994110589284741 loss= 0.0003121664271570621\n",
      "\tgrad: 1.0 2.0 -0.011778821430517894\n",
      "\tgrad: 2.0 4.0 -0.046172980007630926\n",
      "\tgrad: 3.0 6.0 -0.09557806861579543\n",
      "progress: 17 w= 1.9956458879852805 loss= 0.0001706246229305199\n",
      "\tgrad: 1.0 2.0 -0.008708224029438938\n",
      "\tgrad: 2.0 4.0 -0.03413623819540135\n",
      "\tgrad: 3.0 6.0 -0.07066201306448505\n",
      "progress: 18 w= 1.9967809527381737 loss= 9.326038746484765e-05\n",
      "\tgrad: 1.0 2.0 -0.006438094523652627\n",
      "\tgrad: 2.0 4.0 -0.02523733053271826\n",
      "\tgrad: 3.0 6.0 -0.052241274202728505\n",
      "progress: 19 w= 1.9976201197307648 loss= 5.097447086306101e-05\n",
      "\tgrad: 1.0 2.0 -0.004759760538470381\n",
      "\tgrad: 2.0 4.0 -0.01865826131080439\n",
      "\tgrad: 3.0 6.0 -0.03862260091336722\n",
      "progress: 20 w= 1.998240525958391 loss= 2.7861740127856012e-05\n",
      "\tgrad: 1.0 2.0 -0.0035189480832178432\n",
      "\tgrad: 2.0 4.0 -0.01379427648621423\n",
      "\tgrad: 3.0 6.0 -0.028554152326460525\n",
      "progress: 21 w= 1.99869919972735 loss= 1.5228732143933469e-05\n",
      "\tgrad: 1.0 2.0 -0.002601600545300009\n",
      "\tgrad: 2.0 4.0 -0.01019827413757568\n",
      "\tgrad: 3.0 6.0 -0.021110427464781978\n",
      "progress: 22 w= 1.9990383027488265 loss= 8.323754426231206e-06\n",
      "\tgrad: 1.0 2.0 -0.001923394502346909\n",
      "\tgrad: 2.0 4.0 -0.007539706449199102\n",
      "\tgrad: 3.0 6.0 -0.01560719234984198\n",
      "progress: 23 w= 1.9992890056818404 loss= 4.549616284094891e-06\n",
      "\tgrad: 1.0 2.0 -0.0014219886363191492\n",
      "\tgrad: 2.0 4.0 -0.005574195454370212\n",
      "\tgrad: 3.0 6.0 -0.011538584590544687\n",
      "progress: 24 w= 1.999474353368653 loss= 2.486739429417538e-06\n",
      "\tgrad: 1.0 2.0 -0.0010512932626940419\n",
      "\tgrad: 2.0 4.0 -0.004121069589761106\n",
      "\tgrad: 3.0 6.0 -0.008530614050808794\n",
      "progress: 25 w= 1.9996113831376856 loss= 1.3592075910762856e-06\n",
      "\tgrad: 1.0 2.0 -0.0007772337246287897\n",
      "\tgrad: 2.0 4.0 -0.0030467562005451754\n",
      "\tgrad: 3.0 6.0 -0.006306785335127074\n",
      "progress: 26 w= 1.9997126908902887 loss= 7.429187207079447e-07\n",
      "\tgrad: 1.0 2.0 -0.0005746182194226179\n",
      "\tgrad: 2.0 4.0 -0.002252503420136165\n",
      "\tgrad: 3.0 6.0 -0.00466268207967957\n",
      "progress: 27 w= 1.9997875889274812 loss= 4.060661735575354e-07\n",
      "\tgrad: 1.0 2.0 -0.0004248221450375844\n",
      "\tgrad: 2.0 4.0 -0.0016653028085471533\n",
      "\tgrad: 3.0 6.0 -0.0034471768136938863\n",
      "progress: 28 w= 1.9998429619451539 loss= 2.2194855602869353e-07\n",
      "\tgrad: 1.0 2.0 -0.00031407610969225175\n",
      "\tgrad: 2.0 4.0 -0.0012311783499932005\n",
      "\tgrad: 3.0 6.0 -0.0025485391844828342\n",
      "progress: 29 w= 1.9998838998815958 loss= 1.213131374411496e-07\n",
      "\tgrad: 1.0 2.0 -0.00023220023680847746\n",
      "\tgrad: 2.0 4.0 -0.0009102249282886277\n",
      "\tgrad: 3.0 6.0 -0.0018841656015560204\n",
      "progress: 30 w= 1.9999141657892625 loss= 6.630760559646474e-08\n",
      "\tgrad: 1.0 2.0 -0.00017166842147497974\n",
      "\tgrad: 2.0 4.0 -0.0006729402121816719\n",
      "\tgrad: 3.0 6.0 -0.0013929862392156878\n",
      "progress: 31 w= 1.9999365417379913 loss= 3.624255915449335e-08\n",
      "\tgrad: 1.0 2.0 -0.0001269165240174175\n",
      "\tgrad: 2.0 4.0 -0.0004975127741477792\n",
      "\tgrad: 3.0 6.0 -0.0010298514424817995\n",
      "progress: 32 w= 1.9999530845453979 loss= 1.9809538924707548e-08\n",
      "\tgrad: 1.0 2.0 -9.383090920422887e-05\n",
      "\tgrad: 2.0 4.0 -0.00036781716408107457\n",
      "\tgrad: 3.0 6.0 -0.0007613815296476645\n",
      "progress: 33 w= 1.9999653148414271 loss= 1.0827542027017377e-08\n",
      "\tgrad: 1.0 2.0 -6.937031714571162e-05\n",
      "\tgrad: 2.0 4.0 -0.0002719316432120422\n",
      "\tgrad: 3.0 6.0 -0.0005628985014531906\n",
      "progress: 34 w= 1.999974356846045 loss= 5.9181421028034105e-09\n",
      "\tgrad: 1.0 2.0 -5.1286307909848006e-05\n",
      "\tgrad: 2.0 4.0 -0.00020104232700646207\n",
      "\tgrad: 3.0 6.0 -0.0004161576169003922\n",
      "progress: 35 w= 1.9999810417085633 loss= 3.2347513278475087e-09\n",
      "\tgrad: 1.0 2.0 -3.7916582873442906e-05\n",
      "\tgrad: 2.0 4.0 -0.0001486330048638962\n",
      "\tgrad: 3.0 6.0 -0.0003076703200690645\n",
      "progress: 36 w= 1.9999859839076413 loss= 1.7680576050779005e-09\n",
      "\tgrad: 1.0 2.0 -2.8032184717474706e-05\n",
      "\tgrad: 2.0 4.0 -0.0001098861640933535\n",
      "\tgrad: 3.0 6.0 -0.00022746435967313516\n",
      "progress: 37 w= 1.9999896377347262 loss= 9.6638887447731e-10\n",
      "\tgrad: 1.0 2.0 -2.0724530547688857e-05\n",
      "\tgrad: 2.0 4.0 -8.124015974608767e-05\n",
      "\tgrad: 3.0 6.0 -0.00016816713067413502\n",
      "progress: 38 w= 1.999992339052936 loss= 5.282109892545845e-10\n",
      "\tgrad: 1.0 2.0 -1.5321894128117464e-05\n",
      "\tgrad: 2.0 4.0 -6.006182498197177e-05\n",
      "\tgrad: 3.0 6.0 -0.00012432797771566584\n",
      "progress: 39 w= 1.9999943361699042 loss= 2.887107421958329e-10\n",
      "\tgrad: 1.0 2.0 -1.1327660191629008e-05\n",
      "\tgrad: 2.0 4.0 -4.4404427951505454e-05\n",
      "\tgrad: 3.0 6.0 -9.191716585732479e-05\n",
      "progress: 40 w= 1.9999958126624442 loss= 1.5780416225633037e-10\n",
      "\tgrad: 1.0 2.0 -8.37467511161094e-06\n",
      "\tgrad: 2.0 4.0 -3.282872643772805e-05\n",
      "\tgrad: 3.0 6.0 -6.795546372551087e-05\n",
      "progress: 41 w= 1.999996904251097 loss= 8.625295142578772e-11\n",
      "\tgrad: 1.0 2.0 -6.191497806007362e-06\n",
      "\tgrad: 2.0 4.0 -2.4270671399762023e-05\n",
      "\tgrad: 3.0 6.0 -5.0240289795056015e-05\n",
      "progress: 42 w= 1.999997711275687 loss= 4.71443308235547e-11\n",
      "\tgrad: 1.0 2.0 -4.5774486259198e-06\n",
      "\tgrad: 2.0 4.0 -1.794359861406747e-05\n",
      "\tgrad: 3.0 6.0 -3.714324913239864e-05\n",
      "progress: 43 w= 1.9999983079186507 loss= 2.5768253628059826e-11\n",
      "\tgrad: 1.0 2.0 -3.3841626985164908e-06\n",
      "\tgrad: 2.0 4.0 -1.326591777761621e-05\n",
      "\tgrad: 3.0 6.0 -2.7460449796734565e-05\n",
      "progress: 44 w= 1.9999987490239537 loss= 1.4084469615916932e-11\n",
      "\tgrad: 1.0 2.0 -2.5019520926150562e-06\n",
      "\tgrad: 2.0 4.0 -9.807652203264183e-06\n",
      "\tgrad: 3.0 6.0 -2.0301840059744336e-05\n",
      "progress: 45 w= 1.9999990751383971 loss= 7.698320862431846e-12\n",
      "\tgrad: 1.0 2.0 -1.8497232057157476e-06\n",
      "\tgrad: 2.0 4.0 -7.250914967116273e-06\n",
      "\tgrad: 3.0 6.0 -1.5009393983689279e-05\n",
      "progress: 46 w= 1.9999993162387186 loss= 4.20776540913866e-12\n",
      "\tgrad: 1.0 2.0 -1.3675225627451937e-06\n",
      "\tgrad: 2.0 4.0 -5.3606884460322135e-06\n",
      "\tgrad: 3.0 6.0 -1.109662508014253e-05\n",
      "progress: 47 w= 1.9999994944870796 loss= 2.299889814334344e-12\n",
      "\tgrad: 1.0 2.0 -1.0110258408246864e-06\n",
      "\tgrad: 2.0 4.0 -3.963221296032771e-06\n",
      "\tgrad: 3.0 6.0 -8.20386808086937e-06\n",
      "progress: 48 w= 1.9999996262682318 loss= 1.2570789110540446e-12\n",
      "\tgrad: 1.0 2.0 -7.474635363990956e-07\n",
      "\tgrad: 2.0 4.0 -2.930057062755509e-06\n",
      "\tgrad: 3.0 6.0 -6.065218119744031e-06\n",
      "progress: 49 w= 1.999999723695619 loss= 6.870969979249939e-13\n",
      "\tgrad: 1.0 2.0 -5.526087618612507e-07\n",
      "\tgrad: 2.0 4.0 -2.166226346744793e-06\n",
      "\tgrad: 3.0 6.0 -4.484088535150477e-06\n",
      "progress: 50 w= 1.9999997957248556 loss= 3.7555501141274804e-13\n",
      "\tgrad: 1.0 2.0 -4.08550288710785e-07\n",
      "\tgrad: 2.0 4.0 -1.6015171322436572e-06\n",
      "\tgrad: 3.0 6.0 -3.3151404608133817e-06\n",
      "progress: 51 w= 1.9999998489769344 loss= 2.052716967104274e-13\n",
      "\tgrad: 1.0 2.0 -3.020461312175371e-07\n",
      "\tgrad: 2.0 4.0 -1.1840208351543424e-06\n",
      "\tgrad: 3.0 6.0 -2.4509231284497446e-06\n",
      "progress: 52 w= 1.9999998883468353 loss= 1.1219786256679713e-13\n",
      "\tgrad: 1.0 2.0 -2.2330632942768602e-07\n",
      "\tgrad: 2.0 4.0 -8.753608113920563e-07\n",
      "\tgrad: 3.0 6.0 -1.811996877876254e-06\n",
      "progress: 53 w= 1.9999999174534755 loss= 6.132535848018759e-14\n",
      "\tgrad: 1.0 2.0 -1.6509304900935717e-07\n",
      "\tgrad: 2.0 4.0 -6.471647520100987e-07\n",
      "\tgrad: 3.0 6.0 -1.3396310407642886e-06\n",
      "progress: 54 w= 1.999999938972364 loss= 3.351935118167793e-14\n",
      "\tgrad: 1.0 2.0 -1.220552721115098e-07\n",
      "\tgrad: 2.0 4.0 -4.784566662863199e-07\n",
      "\tgrad: 3.0 6.0 -9.904052991061008e-07\n",
      "progress: 55 w= 1.9999999548815364 loss= 1.8321081844499955e-14\n",
      "\tgrad: 1.0 2.0 -9.023692726373156e-08\n",
      "\tgrad: 2.0 4.0 -3.5372875473171916e-07\n",
      "\tgrad: 3.0 6.0 -7.322185204827747e-07\n",
      "progress: 56 w= 1.9999999666433785 loss= 1.0013977760018664e-14\n",
      "\tgrad: 1.0 2.0 -6.671324292994996e-08\n",
      "\tgrad: 2.0 4.0 -2.615159129248923e-07\n",
      "\tgrad: 3.0 6.0 -5.413379398078177e-07\n",
      "progress: 57 w= 1.9999999753390494 loss= 5.473462367088053e-15\n",
      "\tgrad: 1.0 2.0 -4.932190122985958e-08\n",
      "\tgrad: 2.0 4.0 -1.9334185274999527e-07\n",
      "\tgrad: 3.0 6.0 -4.002176350326181e-07\n",
      "progress: 58 w= 1.9999999817678633 loss= 2.991697274308627e-15\n",
      "\tgrad: 1.0 2.0 -3.6464273378555845e-08\n",
      "\tgrad: 2.0 4.0 -1.429399514307761e-07\n",
      "\tgrad: 3.0 6.0 -2.9588569994132286e-07\n",
      "progress: 59 w= 1.9999999865207625 loss= 1.6352086111474931e-15\n",
      "\tgrad: 1.0 2.0 -2.6958475007887728e-08\n",
      "\tgrad: 2.0 4.0 -1.0567722164012139e-07\n",
      "\tgrad: 3.0 6.0 -2.1875184863517916e-07\n",
      "progress: 60 w= 1.999999990034638 loss= 8.937759877335403e-16\n",
      "\tgrad: 1.0 2.0 -1.993072418216002e-08\n",
      "\tgrad: 2.0 4.0 -7.812843882959442e-08\n",
      "\tgrad: 3.0 6.0 -1.617258700292723e-07\n",
      "progress: 61 w= 1.9999999926324883 loss= 4.885220495987371e-16\n",
      "\tgrad: 1.0 2.0 -1.473502342363986e-08\n",
      "\tgrad: 2.0 4.0 -5.7761292637792394e-08\n",
      "\tgrad: 3.0 6.0 -1.195658771990793e-07\n",
      "progress: 62 w= 1.99999999455311 loss= 2.670175009618106e-16\n",
      "\tgrad: 1.0 2.0 -1.0893780100218464e-08\n",
      "\tgrad: 2.0 4.0 -4.270361841918202e-08\n",
      "\tgrad: 3.0 6.0 -8.839649012770678e-08\n",
      "progress: 63 w= 1.9999999959730488 loss= 1.4594702493172377e-16\n",
      "\tgrad: 1.0 2.0 -8.05390243385773e-09\n",
      "\tgrad: 2.0 4.0 -3.1571296688071016e-08\n",
      "\tgrad: 3.0 6.0 -6.53525820126788e-08\n",
      "progress: 64 w= 1.9999999970228268 loss= 7.977204100704301e-17\n",
      "\tgrad: 1.0 2.0 -5.9543463493128e-09\n",
      "\tgrad: 2.0 4.0 -2.334103754719763e-08\n",
      "\tgrad: 3.0 6.0 -4.8315948575350376e-08\n",
      "progress: 65 w= 1.9999999977989402 loss= 4.360197735196887e-17\n",
      "\tgrad: 1.0 2.0 -4.402119557767037e-09\n",
      "\tgrad: 2.0 4.0 -1.725630838222969e-08\n",
      "\tgrad: 3.0 6.0 -3.5720557178819945e-08\n",
      "progress: 66 w= 1.9999999983727301 loss= 2.3832065197304227e-17\n",
      "\tgrad: 1.0 2.0 -3.254539748809293e-09\n",
      "\tgrad: 2.0 4.0 -1.2757796596929438e-08\n",
      "\tgrad: 3.0 6.0 -2.6408640607655798e-08\n",
      "progress: 67 w= 1.9999999987969397 loss= 1.3026183953845832e-17\n",
      "\tgrad: 1.0 2.0 -2.406120636067044e-09\n",
      "\tgrad: 2.0 4.0 -9.431992964437086e-09\n",
      "\tgrad: 3.0 6.0 -1.9524227568012975e-08\n",
      "progress: 68 w= 1.999999999110563 loss= 7.11988308874388e-18\n",
      "\tgrad: 1.0 2.0 -1.7788739370416806e-09\n",
      "\tgrad: 2.0 4.0 -6.97318647269185e-09\n",
      "\tgrad: 3.0 6.0 -1.4434496264925656e-08\n",
      "progress: 69 w= 1.9999999993424284 loss= 3.89160224698574e-18\n",
      "\tgrad: 1.0 2.0 -1.3151431055291596e-09\n",
      "\tgrad: 2.0 4.0 -5.155360582875801e-09\n",
      "\tgrad: 3.0 6.0 -1.067159693945996e-08\n",
      "progress: 70 w= 1.9999999995138495 loss= 2.1270797208746147e-18\n",
      "\tgrad: 1.0 2.0 -9.72300906454393e-10\n",
      "\tgrad: 2.0 4.0 -3.811418736177075e-09\n",
      "\tgrad: 3.0 6.0 -7.88963561149103e-09\n",
      "progress: 71 w= 1.9999999996405833 loss= 1.1626238773828175e-18\n",
      "\tgrad: 1.0 2.0 -7.18833437218791e-10\n",
      "\tgrad: 2.0 4.0 -2.8178277489132597e-09\n",
      "\tgrad: 3.0 6.0 -5.832902161273523e-09\n",
      "progress: 72 w= 1.999999999734279 loss= 6.354692062078993e-19\n",
      "\tgrad: 1.0 2.0 -5.314420015167798e-10\n",
      "\tgrad: 2.0 4.0 -2.0832526814729135e-09\n",
      "\tgrad: 3.0 6.0 -4.31233715403323e-09\n",
      "progress: 73 w= 1.9999999998035491 loss= 3.4733644793346653e-19\n",
      "\tgrad: 1.0 2.0 -3.92901711165905e-10\n",
      "\tgrad: 2.0 4.0 -1.5401742103904326e-09\n",
      "\tgrad: 3.0 6.0 -3.188159070077745e-09\n",
      "progress: 74 w= 1.9999999998547615 loss= 1.8984796531526204e-19\n",
      "\tgrad: 1.0 2.0 -2.9047697580608656e-10\n",
      "\tgrad: 2.0 4.0 -1.1386696030513122e-09\n",
      "\tgrad: 3.0 6.0 -2.3570478902001923e-09\n",
      "progress: 75 w= 1.9999999998926234 loss= 1.0376765851119951e-19\n",
      "\tgrad: 1.0 2.0 -2.1475310418850313e-10\n",
      "\tgrad: 2.0 4.0 -8.418314934033333e-10\n",
      "\tgrad: 3.0 6.0 -1.7425900722400911e-09\n",
      "progress: 76 w= 1.9999999999206153 loss= 5.671751114309842e-20\n",
      "\tgrad: 1.0 2.0 -1.5876944203796484e-10\n",
      "\tgrad: 2.0 4.0 -6.223768167501476e-10\n",
      "\tgrad: 3.0 6.0 -1.2883241140571045e-09\n",
      "progress: 77 w= 1.9999999999413098 loss= 3.100089617511693e-20\n",
      "\tgrad: 1.0 2.0 -1.17380327679939e-10\n",
      "\tgrad: 2.0 4.0 -4.601314884666863e-10\n",
      "\tgrad: 3.0 6.0 -9.524754318590567e-10\n",
      "progress: 78 w= 1.9999999999566096 loss= 1.6944600977692705e-20\n",
      "\tgrad: 1.0 2.0 -8.678080476443029e-11\n",
      "\tgrad: 2.0 4.0 -3.4018121652934497e-10\n",
      "\tgrad: 3.0 6.0 -7.041780492045291e-10\n",
      "progress: 79 w= 1.9999999999679208 loss= 9.2616919156479e-21\n",
      "\tgrad: 1.0 2.0 -6.415845632545825e-11\n",
      "\tgrad: 2.0 4.0 -2.5150193039280566e-10\n",
      "\tgrad: 3.0 6.0 -5.206075570640678e-10\n",
      "progress: 80 w= 1.9999999999762834 loss= 5.062350511130293e-21\n",
      "\tgrad: 1.0 2.0 -4.743316850408519e-11\n",
      "\tgrad: 2.0 4.0 -1.8593837580738182e-10\n",
      "\tgrad: 3.0 6.0 -3.8489211817704927e-10\n",
      "progress: 81 w= 1.999999999982466 loss= 2.7669155644059242e-21\n",
      "\tgrad: 1.0 2.0 -3.5067948545020045e-11\n",
      "\tgrad: 2.0 4.0 -1.3746692673066718e-10\n",
      "\tgrad: 3.0 6.0 -2.845563784603655e-10\n",
      "progress: 82 w= 1.9999999999870368 loss= 1.5124150106147723e-21\n",
      "\tgrad: 1.0 2.0 -2.5926372160256506e-11\n",
      "\tgrad: 2.0 4.0 -1.0163070385260653e-10\n",
      "\tgrad: 3.0 6.0 -2.1037571684701106e-10\n",
      "progress: 83 w= 1.999999999990416 loss= 8.26683933105326e-22\n",
      "\tgrad: 1.0 2.0 -1.9167778475548403e-11\n",
      "\tgrad: 2.0 4.0 -7.51381179497912e-11\n",
      "\tgrad: 3.0 6.0 -1.5553425214420713e-10\n",
      "progress: 84 w= 1.9999999999929146 loss= 4.518126871054872e-22\n",
      "\tgrad: 1.0 2.0 -1.4170886686315498e-11\n",
      "\tgrad: 2.0 4.0 -5.555023108172463e-11\n",
      "\tgrad: 3.0 6.0 -1.1499068364173581e-10\n",
      "progress: 85 w= 1.9999999999947617 loss= 2.469467919185614e-22\n",
      "\tgrad: 1.0 2.0 -1.0476508549572827e-11\n",
      "\tgrad: 2.0 4.0 -4.106759377009439e-11\n",
      "\tgrad: 3.0 6.0 -8.500933290633839e-11\n",
      "progress: 86 w= 1.9999999999961273 loss= 1.349840097651456e-22\n",
      "\tgrad: 1.0 2.0 -7.745359908994942e-12\n",
      "\tgrad: 2.0 4.0 -3.036149109902908e-11\n",
      "\tgrad: 3.0 6.0 -6.285105769165966e-11\n",
      "progress: 87 w= 1.999999999997137 loss= 7.376551550022107e-23\n",
      "\tgrad: 1.0 2.0 -5.726086271806707e-12\n",
      "\tgrad: 2.0 4.0 -2.2446045022661565e-11\n",
      "\tgrad: 3.0 6.0 -4.646416584819235e-11\n",
      "progress: 88 w= 1.9999999999978835 loss= 4.031726170507742e-23\n",
      "\tgrad: 1.0 2.0 -4.233058348290797e-12\n",
      "\tgrad: 2.0 4.0 -1.659294923683774e-11\n",
      "\tgrad: 3.0 6.0 -3.4351188560322043e-11\n",
      "progress: 89 w= 1.9999999999984353 loss= 2.2033851437431755e-23\n",
      "\tgrad: 1.0 2.0 -3.1294966618133913e-12\n",
      "\tgrad: 2.0 4.0 -1.226752033289813e-11\n",
      "\tgrad: 3.0 6.0 -2.539835008974478e-11\n",
      "progress: 90 w= 1.9999999999988431 loss= 1.2047849775995315e-23\n",
      "\tgrad: 1.0 2.0 -2.3137047833188262e-12\n",
      "\tgrad: 2.0 4.0 -9.070078021977679e-12\n",
      "\tgrad: 3.0 6.0 -1.8779644506139448e-11\n",
      "progress: 91 w= 1.9999999999991447 loss= 6.5840863393251405e-24\n",
      "\tgrad: 1.0 2.0 -1.7106316363424412e-12\n",
      "\tgrad: 2.0 4.0 -6.7057470687359455e-12\n",
      "\tgrad: 3.0 6.0 -1.3882228699912957e-11\n",
      "progress: 92 w= 1.9999999999993676 loss= 3.5991747246272455e-24\n",
      "\tgrad: 1.0 2.0 -1.2647660696529783e-12\n",
      "\tgrad: 2.0 4.0 -4.957811938766099e-12\n",
      "\tgrad: 3.0 6.0 -1.0263789818054647e-11\n",
      "progress: 93 w= 1.9999999999995324 loss= 1.969312363793734e-24\n",
      "\tgrad: 1.0 2.0 -9.352518759442319e-13\n",
      "\tgrad: 2.0 4.0 -3.666400516522117e-12\n",
      "\tgrad: 3.0 6.0 -7.58859641791787e-12\n",
      "progress: 94 w= 1.9999999999996543 loss= 1.0761829795642296e-24\n",
      "\tgrad: 1.0 2.0 -6.914468997365475e-13\n",
      "\tgrad: 2.0 4.0 -2.7107205369247822e-12\n",
      "\tgrad: 3.0 6.0 -5.611511255665391e-12\n",
      "progress: 95 w= 1.9999999999997444 loss= 5.875191475205477e-25\n",
      "\tgrad: 1.0 2.0 -5.111466805374221e-13\n",
      "\tgrad: 2.0 4.0 -2.0037305148434825e-12\n",
      "\tgrad: 3.0 6.0 -4.1460168631601846e-12\n",
      "progress: 96 w= 1.999999999999811 loss= 3.2110109830478153e-25\n",
      "\tgrad: 1.0 2.0 -3.779199175824033e-13\n",
      "\tgrad: 2.0 4.0 -1.4814816040598089e-12\n",
      "\tgrad: 3.0 6.0 -3.064215547965432e-12\n",
      "progress: 97 w= 1.9999999999998603 loss= 1.757455879087579e-25\n",
      "\tgrad: 1.0 2.0 -2.793321129956894e-13\n",
      "\tgrad: 2.0 4.0 -1.0942358130705543e-12\n",
      "\tgrad: 3.0 6.0 -2.2648549702353193e-12\n",
      "progress: 98 w= 1.9999999999998967 loss= 9.608404711682446e-26\n",
      "\tgrad: 1.0 2.0 -2.0650148258027912e-13\n",
      "\tgrad: 2.0 4.0 -8.100187187665142e-13\n",
      "\tgrad: 3.0 6.0 -1.6786572132332367e-12\n",
      "progress: 99 w= 1.9999999999999236 loss= 5.250973729513143e-26\n",
      "predict (after training) 4 7.9999999999996945\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEGCAYAAABvtY4XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAVTUlEQVR4nO3de4ydd33n8fd35hzPGSce20mmwbEdHC7LQtnGCbOUbNoum7QIKCqsuARaKOoiWd1l1bBC223UVqtW+8eudkXZqojFCmxDm6VASFoUcU9paCRyGQcn5MItIQEHB08IviTxZWb83T+e54yPjd1MbD9zPL/n/ZJGnnOZ8/s9euyPf/N9fs/vF5mJJKk8I8PugCSpGQa8JBXKgJekQhnwklQoA16SCtUZdgcGnXfeeblp06Zhd0OSlo1t27Y9kZmTx3vtjAr4TZs2MT09PexuSNKyERGPnug1SzSSVCgDXpIK1WiJJiIeAfYB88BcZk412Z4k6YilqMH/m8x8YgnakSQNsEQjSYVqOuAT+FJEbIuILcd7Q0RsiYjpiJiemZlpuDuS1B5NB/wvZealwOuA90bErxz7hszcmplTmTk1OXncqZySpJPQaMBn5mP1n7uAm4BXNtHOn9/yXW79jqN/SRrUWMBHxFkRsar/PfAa4L4m2vrIrQ/xjwa8JB2lyVk05wM3RUS/nf+XmV9ooqGx7igH5uab+GhJWrYaC/jMfBi4uKnPH9TrjHBg9vBSNCVJy0YR0yR73VEOzhnwkjSoiIAf645yYNYSjSQNKiLge90RA16SjlFEwI91RjhoDV6SjlJEwFc1eEfwkjSojIDvjDqLRpKOUUbAd0ecBy9Jxygi4Mc6zqKRpGMVEfDVLBpLNJI0qJCA9yKrJB2riICvbnQ6TGYOuyuSdMYoIuB73eowXK5Ako4oIuDHOqMA3uwkSQOKCPj+CN6pkpJ0RBkB7whekn5GGQHfrQLeEbwkHVFEwI916hKNNztJ0oIiAn5hBG+JRpIWFBLw/WmSjuAlqa+QgHcEL0nHKiLgrcFL0s8qIuCPjOANeEnqKyLgxxZudLJEI0l9RQR8fwR/0BG8JC0oI+D7d7I6gpekBUUEfHc0iLAGL0mDigj4iKg33jbgJamviIAHt+2TpGMVFPBu2ydJg4oJ+LGOI3hJGlRMwPe61uAlaVDjAR8RoxHxjYi4ucl2xrqj3ugkSQOWYgR/NfBg0430OiPe6CRJAxoN+IjYAPw6cG2T7UBdonEEL0kLmh7BfxD4feCEyRsRWyJiOiKmZ2ZmTrqhMUfwknSUxgI+It4A7MrMbf/U+zJza2ZOZebU5OTkSbfnRVZJOlqTI/jLgd+IiEeAvwGuiIi/bqoxb3SSpKM1FvCZeU1mbsjMTcDbgb/PzHc21Z43OknS0YqZB++NTpJ0tM5SNJKZ/wD8Q5NtVLNo5slMIqLJpiRpWShmBN/rjpIJh+YdxUsSFBTwRzbeNuAlCQoK+IVt+7zQKklAQQHfH8EfdAQvSUBBAd8fwXuzkyRVCgx4R/CSBEUFfF2isQYvSUBBAT/WcQQvSYOKCfj+CN4avCRVCgr4egRviUaSgJIC3hKNJB2lmIAf8yKrJB2lmIB3BC9JRysm4Me8yCpJRykn4DsjROC+rJJUKybgI6La9GPOEo0kQUEBD9XNTo7gJalSVMC78bYkHVFYwI96o5Mk1coK+M6os2gkqVZUwI91RzjoRVZJAgoLeEfwknREUQE/5kVWSVpQVMD3uo7gJamvuIC3Bi9JlaICfqwz4o1OklQrKuB7XZcqkKS+sgLeWTSStKCsgK8vsmbmsLsiSUNXVMCPdUY4nDA7b8BLUlEB39942237JKnBgI+IXkTcGRH3RMT9EfEnTbXV11vY1ckLrZLUafCzDwJXZOZTEdEFbouIz2fm7U01ONbt78vqCF6SGgv4rK50PlU/7NZfjRbHxzrVCN4SjSQ1XIOPiNGI
2A7sAr6cmXcc5z1bImI6IqZnZmZOqb3ewgjeEo0kNRrwmTmfmZuBDcArI+Llx3nP1sycysypycnJU2rPi6ySdMSSzKLJzN3AV4HXNtlOr+NFVknqa3IWzWRErKm/Hwd+DfhWU+3BYInGEbwkNTmLZh1wXUSMUv1H8qnMvLnB9hhzmqQkLWhyFs29wCVNff7x9DrW4CWpr8g7WR3BS1JxAd8v0TiCl6SiAn6sLtEcsEQjSaUFvBdZJamvqIAfGQlWdEa8yCpJFBbwUN3sdNARvCQVGPBdt+2TJCgw4M8e6/DUwblhd0OShq64gF813mXvAQNekooL+Ileh737Z4fdDUkaukUFfERcHRETUfloRNwdEa9punMnY2K8y94DBrwkLXYE/+8ycy/wGmAt8C7gvzfWq1Mw0euyd78lGklabMBH/efrgb/KzPsHnjujTIx3HMFLEosP+G0R8SWqgP9iRKwCzsjJ5hO9LofmDjtVUlLrLXa54PcAm4GHM/OZiDgH+J3GenUKJsa7AOzdP7uwuqQktdFiR/CXAd/OzN0R8U7gj4A9zXXr5K3uB7xlGkktt9iA/zDwTERcDLwfeAj4eGO9OgUTveqXkj1eaJXUcosN+LnMTOCNwF9k5oeAVc116+RNOIKXJGDxNfh9EXEN1fTIX46IEaDbXLdO3kTvSA1ektpssSP4q4CDVPPhHwc2AP+zsV6dgonx6v8slyuQ1HaLCvg61K8HVkfEG4ADmXmG1uAdwUsSLH6pgrcBdwJvBd4G3BERb2myYyer1x1lRWfEGryk1ltsDf4PgX+ZmbsAImIS+ApwQ1MdOxUuVyBJi6/Bj/TDvfaT5/CzS87lCiRp8SP4L0TEF4FP1I+vAj7XTJdOXTWCN+AltduiAj4z/3NEvBm4vH5qa2be1Fy3Ts3EeJc9BrykllvsCJ7M/AzwmQb7ctpM9DrsePKZYXdDkobqnwz4iNgH5PFeAjIzJxrp1Sly0w9JepaAz8wzcjmCZ7N6vJpFk5lEnJHL1ktS487YmTCnYqLX5dD8YQ7OnZFL1kvSkigz4PvLFXihVVKLNRbwEbExIr4aEQ9ExP0RcXVTbR2rv1yBM2kktdmiZ9GchDng/Zl5d73F37aI+HJmPtBgm4BLBksSNDiCz8ydmXl3/f0+4EFgfVPtDepv+uFyBZLabElq8BGxCbgEuGMp2nMEL0lLEPARcTbVDVLvy8y9x3l9S0RMR8T0zMzMaWnTJYMlqeGAj4guVbhfn5k3Hu89mbk1M6cyc2pycvK0tLuq56YfktTkLJoAPgo8mJkfaKqd4+l1RxnrjDiCl9RqTY7gL6faw/WKiNhef72+wfaO4nIFktqusWmSmXkb1Zo1QzHR6ziLRlKrFXknKziCl6RiA75acMyAl9RexQb8RK/rLBpJrVZuwI93HMFLarVyA75X1eAzj7dfiSSVr9yAH+8yO58cmHVNeEntVG7A91yPRlK7lRvwbvohqeXKDXhH8JJartyA7y8Z7N2sklqq3ICvV5R02z5JbVVuwLvph6SWKzbgF9aEdwQvqaWKDfixzii97ojLFUhqrWIDHlxwTFK7FR3wE72uF1kltVbRAX/e2WPs2ndw2N2QpKEoOuDXremxc/f+YXdDkoai6IBfv2acH+87yNy8C45Jap+iA37d6nHmD6dlGkmtVHTAX7CmB8DOPZZpJLVP4QE/DsBjuw8MuSeStPSKDvh1q+sRvBdaJbVQ0QG/qtdlVa/Dzj2O4CW1T9EBD3DB6nEecwQvqYWKD/h1a3peZJXUSsUH/AVrxvmRF1kltVD5Ab+6x5NPH+LA7PywuyJJS6r4gF+3upoq6YVWSW1TfMD358L/yAutklqmBQFfzYU34CW1TfEB/7zV/YC3RCOpXRoL+Ij4WETsioj7mmpjMcY6o5x39phTJSW1TpMj+L8EXtvg5y/aBWt63uwkqXUaC/jM/BrwZFOf/1xcsHrcWTSSWmfoNfiI2BIR0xExPTMz
00gb/Z2dMrORz5ekM9HQAz4zt2bmVGZOTU5ONtLGBavHefrQPHv3zzXy+ZJ0Jhp6wC+FhbnwXmiV1CKtCPh1zoWX1EJNTpP8BPB14CURsSMi3tNUW89m/cII3gutktqj09QHZ+Y7mvrs5+q8s8fojIQ7O0lqlVaUaEZHgvMnepZoJLVKKwIeqjLNjp8a8JLaozUB/9J1q3hg517mDzsXXlI7tCbgL964hmcOzfO9XU8NuyuStCRaFfAA9/xw91D7IUlLpTUBf9G5Z7Gq12H7jt3D7ookLYnWBPzISHDxhjXca8BLaonWBDzAL2xYzbd27nMDbkmt0KqAv3jjGuYOJ/f/aO+wuyJJjWtVwG/2QqukFmlVwJ8/0eN5Ez3usQ4vqQVaFfBQ1eHv3bFn2N2QpMa1LuAv3riG7z/xNLufOTTsrkhSo1oX8P06vKN4SaVrXcD/iw2rAS+0Sipf6wJ+otflhZNneaFVUvFaF/AAU88/h9sffpL9h7zhSVK5Whnw//bS9Tx1cI7P37dz2F2RpMa0MuB/8aJz2HTuSj551w+H3RVJakwrAz4ieOvURu74/pM88sTTw+6OJDWilQEP8OZLNzAS8OltjuIllam1Af+81T3+9T+b5IZtO9zGT1KRWhvwAG+b2siP9x7ka9+dGXZXJOm0a3XAX/nS8znnrBV8youtkgrU6oBf0RnhbVMb+cL9j/P1h34y7O5I0mnV6oAH+L0rX8Smc8/i/Z/azp79s8PujiSdNq0P+JUrOnzwqs38eN9B/vhv7xt2dyTptGl9wEO1hPD7rnwxn73nR/zd9seG3R1JOi0M+Nq/f/ULufTCNVxz4zf53DddwkDS8mfA1zqjI3z4na/gJc9bxX+4/m7+280PMDt/eNjdkqSTZsAPOH+ixye3XMZvX/Z8rr3t+7x96+3c+p0ZDnsjlKRlqNGAj4jXRsS3I+J7EfEHTbZ1uqzojPCnb3w5H7xqM4/+5Gne/bE7+dUP3Mq1//gw3358n2EvadmIzGYCKyJGge8AvwbsAO4C3pGZD5zoZ6ampnJ6erqR/pyMg3PzfP6bj3Pd1x/hGz/YDcBEr8PmC9fygvPOYsPacTasHWftyhWsWbmC1eNdVo6NMt4dpTvqL0eSmhcR2zJz6nivdRps95XA9zLz4boTfwO8EThhwJ9pxjqjvOmS9bzpkvX84CfPcNcjTzL96JNs/+Ee7n70pzx1cO6EPzs6EnRHg+7ICN3OCCMRjI7AaAQRwcgIjEQQ9ftj4Pv+NzHweRGDj57dc3u3pGFau3IFn/rdy0775zYZ8OuBwTUAdgC/eOybImILsAXgwgsvbLA7p+bCc1dy4bkrefMrNgCQmezZP8tju/ez+5nZ6mv/IfYfmufA7Dz7Z+eZnU9m5w8zO3+Y+cNw+HAyn8nhTEiYr397qh8ufC4cefyzD55dPtcfkDRUE71uI5/bZMAvSmZuBbZCVaIZcncWLSJYU5dmJOlM1GSh+DFg48DjDfVzkqQl0GTA3wW8OCIuiogVwNuBzzbYniRpQGMlmsyci4j/CHwRGAU+lpn3N9WeJOlojdbgM/NzwOeabEOSdHxO1pakQhnwklQoA16SCmXAS1KhGluL5mRExAzw6En++HnAE6exO8tBG48Z2nncbTxmaOdxP9djfn5mTh7vhTMq4E9FREyfaMGdUrXxmKGdx93GY4Z2HvfpPGZLNJJUKANekgpVUsBvHXYHhqCNxwztPO42HjO087hP2zEXU4OXJB2tpBG8JGmAAS9JhVr2Ab8cN/Y+GRGxMSK+GhEPRMT9EXF1/fw5EfHliPhu/efaYff1dIuI0Yj4RkTcXD++KCLuqM/5J+vlqIsSEWsi4oaI+FZEPBgRl5V+riPiP9V/t++LiE9ERK/Ecx0RH4uIXRFx38Bzxz23Ufnz+vjvjYhLn0tbyzrg6429PwS8DngZ8I6IeNlwe9WYOeD9mfky4FXAe+tj/QPglsx8MXBL/bg0VwMPDjz+
H8CfZeaLgJ8C7xlKr5r1v4EvZOY/By6mOv5iz3VErAd+D5jKzJdTLTH+dso8138JvPaY5050bl8HvLj+2gJ8+Lk0tKwDnoGNvTPzENDf2Ls4mbkzM++uv99H9Q9+PdXxXle/7TrgTUPpYEMiYgPw68C19eMArgBuqN9S4jGvBn4F+ChAZh7KzN0Ufq6pli8fj4gOsBLYSYHnOjO/Bjx5zNMnOrdvBD6elduBNRGxbrFtLfeAP97G3uuH1JclExGbgEuAO4DzM3Nn/dLjwPnD6ldDPgj8PnC4fnwusDsz5+rHJZ7zi4AZ4P/WpalrI+IsCj7XmfkY8L+AH1AF+x5gG+Wf674TndtTyrjlHvCtExFnA58B3peZewdfy2rOazHzXiPiDcCuzNw27L4ssQ5wKfDhzLwEeJpjyjEFnuu1VKPVi4ALgLP42TJGK5zOc7vcA75VG3tHRJcq3K/PzBvrp3/c/5Wt/nPXsPrXgMuB34iIR6jKb1dQ1abX1L/GQ5nnfAewIzPvqB/fQBX4JZ/rXwW+n5kzmTkL3Eh1/ks/130nOrenlHHLPeBbs7F3XXv+KPBgZn5g4KXPAu+uv3838HdL3bemZOY1mbkhMzdRndu/z8zfAr4KvKV+W1HHDJCZjwM/jIiX1E9dCTxAweeaqjTzqohYWf9d7x9z0ed6wInO7WeB365n07wK2DNQynl2mbmsv4DXA98BHgL+cNj9afA4f4nq17Z7ge311+upatK3AN8FvgKcM+y+NnT8rwZurr9/AXAn8D3g08DYsPvXwPFuBqbr8/23wNrSzzXwJ8C3gPuAvwLGSjzXwCeorjPMUv229p4TnVsgqGYKPgR8k2qW0aLbcqkCSSrUci/RSJJOwICXpEIZ8JJUKANekgplwEtSoQx46TSIiFf3V7uUzhQGvCQVyoBXq0TEOyPizojYHhEfqdeafyoi/qxei/yWiJis37s5Im6v1+G+aWCN7hdFxFci4p6IuDsiXlh//NkDa7hfX9+RKQ2NAa/WiIiXAlcBl2fmZmAe+C2qha2mM/PngVuB/1r/yMeB/5KZv0B1F2H/+euBD2XmxcC/ororEaoVPt9HtTfBC6jWUpGGpvPsb5GKcSXwCuCuenA9TrWo02Hgk/V7/hq4sV6TfU1m3lo/fx3w6YhYBazPzJsAMvMAQP15d2bmjvrxdmATcFvjRyWdgAGvNgngusy85qgnI/74mPed7PodBwe+n8d/XxoySzRqk1uAt0TEz8HCPpjPp/p30F+x8DeB2zJzD/DTiPjl+vl3AbdmtZvWjoh4U/0ZYxGxcikPQlosRxhqjcx8ICL+CPhSRIxQreb3XqoNNV5Zv7aLqk4P1bKt/6cO8IeB36mffxfwkYj40/oz3rqEhyEtmqtJqvUi4qnMPHvY/ZBON0s0klQoR/CSVChH8JJUKANekgplwEtSoQx4SSqUAS9Jhfr/ZzNpuA7TGDIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    " \n",
    "# Same training set as above: y = 2x, optimal weight w = 2.\n",
    "x_data = [1.0, 2.0, 3.0]\n",
    "y_data = [2.0, 4.0, 6.0]\n",
    " \n",
    "# Initial weight guess.\n",
    "w = 1.0\n",
    " \n",
    "# Linear model prediction: y_hat = x * w.\n",
    "def forward(x):\n",
    "    return x*w\n",
    " \n",
    "# Squared-error loss for a single training sample.\n",
    "def loss(x, y):\n",
    "    y_pred = forward(x)\n",
    "    return (y_pred - y)**2\n",
    " \n",
    "# Analytic gradient of the single-sample loss w.r.t. w: 2*x*(x*w - y).\n",
    "def gradient(x, y):\n",
    "    return 2*x*(x*w - y)\n",
    " \n",
    "epoch_list = []\n",
    "loss_list = []\n",
    "print('predict (before training)', 4, forward(4))\n",
    "for epoch in range(100):\n",
    "    for x,y in zip(x_data, y_data):\n",
    "        grad = gradient(x,y)\n",
    "        w = w - 0.01*grad    # stochastic GD: update w after every single sample\n",
    "        print(\"\\tgrad:\", x, y,grad)\n",
    "        l = loss(x,y)\n",
    "    # NOTE: `l` here is only the loss of the LAST sample of the epoch,\n",
    "    # not an average over the dataset — that is what gets plotted below.\n",
    "    print(\"progress:\",epoch,\"w=\",w,\"loss=\",l)\n",
    "    epoch_list.append(epoch)\n",
    "    loss_list.append(l)\n",
    " \n",
    "print('predict (after training)', 4, forward(4))\n",
    "plt.plot(epoch_list,loss_list)\n",
    "plt.ylabel('loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![alt text](image-1.png)\n",
    "使用所有数据计算梯度（批量梯度下降）可以并行计算、时间复杂度低，但优化性能较差；每次用一个样本计算梯度（随机梯度下降）优化性能好，但无法并行、时间复杂度高。折中方案是采用 mini-batch：兼顾较低的时间复杂度和较好的优化性能。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
